import sys
sys.path.append("../../")
import re
from collections import Counter

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from functional import seq  # PyFunctional, used to build the index-to-word list
from pytorch_transformers import BertTokenizer
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torch.nn import Parameter
from torch.nn.functional import mse_loss
from torch.optim import Adagrad
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class GloveDataset(Dataset):
    """ 将一份英文语料切词，然后组织成可供 GloVe训练用的 共现矩阵 """
    REGEX_WORD = re.compile(r"\b[a-zA-Z]{2,}\b")

    def __init__(self, docs, min_word_occurrences=10, oov_token='<oov>', window_size=5):
        # Numericalize the corpus: tokenize, then map words to integer ids
        docs_tok = [self.REGEX_WORD.findall(doc.lower()) for doc in docs]  # docs tokenized
        word_counter = {w: c for w, c in Counter(w for d in docs_tok for w in d).items()
                        if c > min_word_occurrences}  # a plain dict comprehension is about twice as fast as seq here
        w2i = {oov_token: 0}
        docs_tok_id = [[w2i.setdefault(w, len(w2i)) if w in word_counter else 0
                        for w in doc] for doc in docs_tok]  # docs tokenized, as ids
        self.w2i, self.i2w = w2i, seq(w2i.items()).order_by(lambda w_i: w_i[1]).smap(lambda w, i: w).to_list()
        self.n_words = len(w2i)  # note: not len(word_counter), which would omit the OOV token and lead to out-of-range indices

        # Accumulate the co-occurrence matrix
        comatrix = Counter()
        for words_id in tqdm(docs_tok_id, desc='docs2comtx'):
            for i, w1 in enumerate(words_id):  # mind the window limit
                for j, w2 in enumerate(words_id[i + 1: i + window_size], start=i + 1):
                    comatrix[(w1, w2)] += 1 / (j - i)  # weight decays with the distance between the two words
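        # Hypothetical example: for the id sequence [3, 7, 9] with window_size=5,
        # pair (3, 7) accumulates 1/1, (3, 9) accumulates 1/2, and (7, 9) accumulates 1/1,
        # so closer neighbors contribute more to the co-occurrence score.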

        # Extract training samples from the co-occurrence matrix:
        # (center word index A, neighbor word index B) -> co-occurrence score of A and B
        print('extracting (a_word, b_words, co_score) from comatrix')
        a_words, b_words, co_score = zip(*((left, right, x) for (left, right), x in comatrix.items()))
        self.L_words = torch.LongTensor(a_words)
        self.R_words = torch.LongTensor(b_words)
        self.Y = torch.FloatTensor(co_score)

    def __len__(self):
        return len(self.Y)

    def __getitem__(self, item):
        return self.L_words[item], self.R_words[item], self.Y[item]
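
# Minimal usage sketch (hypothetical toy corpus; any iterable of raw strings works):
#
#   docs = ["the quick brown fox jumps over the lazy dog"] * 300
#   dataset = GloveDataset(docs, min_word_occurrences=10, window_size=5)
#   left, right, score = dataset[0]  # (LongTensor, LongTensor, FloatTensor) triple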

class GloVe(nn.Module):
    def __init__(self, vocab_size, index2word, word2index, embed_size=300, y_max=100, alpha=0.75):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.y_max, self.alpha = y_max, alpha
        # Parameter shapes depend on the training data, so the actual parameters
        # (a_vecs/b_vecs/a_bias/b_bias) are created in build_model_from_dataset()
        self.a_vecs = self.b_vecs = self.a_bias = self.b_bias = None
        self.i2w = index2word
        self.w2i = word2index  # likewise, the real vocab is assigned in build_model_from_dataset()

    def fit(self, dataset: GloveDataset, lr=0.05, batch_size=512, n_epochs=3):
        self.build_model_from_dataset(dataset)  # initialize model parameters from the training dataset
        optimizer = Adagrad(self.parameters(), lr=lr)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)  # set num_workers=0 for easier debugging
        for epoch in range(n_epochs):
            for batch_idx, (L, R, Y) in enumerate(dataloader):
                loss = self(L.cuda(), R.cuda(), Y.cuda())
                if batch_idx % 100 == 0:
                    print('epoch/batch %03d/%03d, loss = %6.3f' % (epoch + 1, batch_idx + 1, loss.item()))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        print('[DONE] training model')

    def build_model_from_dataset(self, dataset: GloveDataset, std=0.01):
        # parameter shapes come from the dataset's vocabulary size
        get_var_by_shape = lambda *shape: Parameter(torch.randn(*shape).mul(std).cuda(), requires_grad=True)
        self.a_vecs = get_var_by_shape(dataset.n_words, self.embed_size)  # center-word vectors
        self.b_vecs = get_var_by_shape(dataset.n_words, self.embed_size)  # neighbor-word vectors
        self.a_bias = get_var_by_shape(dataset.n_words)
        self.b_bias = get_var_by_shape(dataset.n_words)
        self.i2w, self.w2i = dataset.i2w, dataset.w2i

    def forward(self, L, R, Y):
        # GloVe objective: J = sum_ij f(X_ij) * (a_i . b_j + bias_a_i + bias_b_j - log X_ij)^2
        W = Y.div(self.y_max).pow(self.alpha).clamp_max(1.0)  # sample weights from the co-occurrence score Y, i.e. f(X_ij) in the formula
        pred = torch.einsum('nd,nd->n', self.a_vecs[L], self.b_vecs[R]) + self.a_bias[L] + self.b_bias[R]
        target = (Y + 1).log()  # add 1 to avoid log(0)
        return W @ mse_loss(pred, target, reduction='none')  # weighted sum of per-sample squared errors; note reduction='none'

    @property
    def embeddings(self):  # return the sum of the center vectors and the neighbor vectors; simple and crude
        return self.a_vecs + self.b_vecs

    def show_vec_space(self, n_show_vecs=300):
        # Run PCA first, then t-SNE; t-SNE directly on the raw embeddings is very slow, even with metric='euclidean'
        embed_pca = PCA(n_components=4).fit_transform(self.embeddings[:n_show_vecs, :].cpu().detach().numpy())
        embed_tsne = TSNE(metric='euclidean', verbose=1, n_jobs=4).fit_transform(embed_pca)
        # in a Jupyter Notebook, use %matplotlib inline
        fig, ax = plt.subplots(figsize=(20, 14))
        for idx in range(n_show_vecs):
            x, y = embed_tsne[idx, :]
            ax.scatter(x, y, color='steelblue')
            ax.annotate(self.i2w[idx], (x, y), alpha=0.7)
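
# End-to-end training sketch (hypothetical corpus and hyperparameters; assumes a CUDA
# device is available, since build_model_from_dataset()/fit() put tensors on the GPU):
#
#   dataset = GloveDataset(docs, window_size=5)
#   model = GloVe(dataset.n_words, dataset.i2w, dataset.w2i, embed_size=300)
#   model.fit(dataset, lr=0.05, batch_size=512, n_epochs=3)
#   model.show_vec_space(n_show_vecs=300)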


class CBoW(nn.Module):
    def __init__(self, bert_dir):
        super().__init__()  # required before assigning attributes on an nn.Module
        self.tokenizer = BertTokenizer.from_pretrained(bert_dir)


    def trainIters(self, tr_loader, dev_loader, te_loader,
                   max_epochs=10, print_every=100, valid_every=1000,
                   learning_rate=2e-5, model_file="CBoW.pkl"
                   ):
        pass  # training loop not implemented yet; see the sketch below
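        # A minimal sketch of a possible training loop (assumptions: each loader yields
        # (context_ids, center_id) batches and self(context_ids) returns vocab logits;
        # neither is defined in this stub):
        #
        #   optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        #   criterion = nn.CrossEntropyLoss()
        #   for epoch in range(max_epochs):
        #       for step, (context_ids, center_id) in enumerate(tr_loader):
        #           loss = criterion(self(context_ids), center_id)
        #           optimizer.zero_grad()
        #           loss.backward()
        #           optimizer.step()
        #           if step % print_every == 0:
        #               print('epoch/step %03d/%05d, loss = %6.3f' % (epoch + 1, step, loss.item()))
        #   torch.save(self.state_dict(), model_file)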

